This paper describes LoraHub, a framework for combining LoRA modules to achieve adaptable performance on unseen tasks. See this notebook from Crumb, which predates it.
📝 Paper: https://arxiv.org/abs/2307.13269 💻 GitHub: https://github.com/sail-sg/lorahub
Problem statement
This paper is focused on pre-trained encoder-decoder and decoder-only transformer models.
The goal is to improve cross-task generalization, a term that covers both zero-shot and few-shot learning. The idea is that you could [[LoRA – Low-Rank Adaptation of Large Language Models|LoRA]]-tune a model on the few-shot examples, but this is “inefficient, time-consuming, and unstable” when the number of examples is small.
Methodology
It requires several LoRA modules (r=16), which are first trained on a variety of upstream tasks (e.g., boolean expressions).
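For concreteness, here is a minimal sketch of how one such upstream module could be trained with the Hugging Face peft library. Only the rank (r=16) and the Flan-T5 base model come from the paper; the scaling factor, dropout, and target modules are illustrative assumptions.

```python
# Sketch: training one upstream LoRA module with Hugging Face peft.
# Only r=16 and the Flan-T5 base model come from the paper;
# lora_alpha, lora_dropout, and target_modules are assumptions.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

base_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(base_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_name)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,                       # rank used for the upstream modules in the paper
    lora_alpha=32,              # assumed scaling factor
    lora_dropout=0.05,          # assumed
    target_modules=["q", "v"],  # assumed: T5 attention query/value projections
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

# ...train on one upstream task (e.g., boolean expressions), then save the adapter:
# model.save_pretrained("lora-boolean-expressions")
```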
LoraHub has two phases:
- Compose: all available LoRA modules (they can be filtered) are merged into one module using weights (positive or negative): $\hat{m} = (w_1 A_1 + w_2 A_2 + \dots + w_N A_N)(w_1 B_1 + w_2 B_2 + \dots + w_N B_N)$.
- Adapt: the assembled LoRA module is used in combination with the base LLM, and its performance on the few-shot examples is used to select the best weights $w_i$. The objective is to minimize the cross-entropy loss with L1 regularization. This is done with an algorithm from Shiwa, the Covariance Matrix Adaptation Evolution Strategy (CMA-ES), implemented with the Nevergrad optimization library (a sketch of both phases follows this list).
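Below is a minimal sketch of the two phases in Python. It is not the released LoraHub implementation: the LoRA matrices are treated as plain NumPy arrays, `few_shot_loss` is a hypothetical helper that evaluates the merged module on the few-shot examples, and the weight bounds and L1 coefficient are assumptions.

```python
# Sketch of Compose (weighted merge of LoRA matrices) and Adapt (CMA-ES weight search).
# `modules` is a list of dicts holding the A and B matrices of each LoRA module;
# `few_shot_loss` is a hypothetical helper, not part of the released code.
import numpy as np
import nevergrad as ng

def compose(modules, weights):
    """Compose: element-wise weighted sums of the A and B matrices of the N modules."""
    A_hat = sum(w * m["A"] for w, m in zip(weights, modules))
    B_hat = sum(w * m["B"] for w, m in zip(weights, modules))
    return A_hat, B_hat  # the merged module contributes A_hat @ B_hat as the LoRA delta

def objective(weights, modules, few_shot_examples, l1_coeff=0.05):
    """Adapt objective: cross-entropy on the few-shot examples plus an L1 penalty."""
    A_hat, B_hat = compose(modules, weights)
    loss = few_shot_loss(A_hat, B_hat, few_shot_examples)  # hypothetical evaluation helper
    return loss + l1_coeff * float(np.abs(np.asarray(weights)).sum())

def adapt(modules, few_shot_examples, budget=100):
    """Search the weights (which may be negative) with CMA-ES via Nevergrad."""
    n = len(modules)
    params = ng.p.Array(shape=(n,)).set_bounds(-1.5, 1.5)  # assumed bounds on the w_i
    optimizer = ng.optimizers.CMA(parametrization=params, budget=budget)
    best = optimizer.minimize(lambda w: objective(w, modules, few_shot_examples))
    return compose(modules, best.value)
```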
I’m skeptical about this optimization technique. I would like to know how it compares to a centroid-based approach for instance.
Evaluation
They use Flan-T5 large (783M parameters) as the base LLM. They train a LoRA module for each of the 200 distinct tasks from the FLAN 2022 collection and release them on Hugging Face. Because of this high number, they pre-filter the modules by randomly selecting 20 of them for each “experimental sequence.”
This pre-filtering could be easily improved.
Finally, LoraHub is evaluated on the BBH (Big-Bench Hard) benchmark, with Exact Match as the evaluation metric.
It is not as good as ICL, but consumes five times fewer tokens.
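As a reminder of how the metric works, Exact Match simply checks whether the model's normalized output matches the reference answer exactly; a minimal sketch (the normalization rules here are an assumption):

```python
def _normalize(text: str) -> str:
    # Assumed normalization: strip whitespace and lowercase.
    return text.strip().lower()

def exact_match(prediction: str, reference: str) -> bool:
    """Return True if the normalized prediction equals the normalized reference."""
    return _normalize(prediction) == _normalize(reference)

def em_score(predictions, references):
    """Benchmark score: fraction of examples answered exactly right."""
    return sum(exact_match(p, r) for p, r in zip(predictions, references)) / len(references)
```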